from full_system_utils import *
from experiments_utils import *

import random



class DoppelWriter:
    """Container for per-model LoRA deltas, a VAE over those deltas, and a quantized LLM.

    On construction it loads, for every sub-folder of ``data_folder``, a pickled
    LoRA delta plus its train/test texts, then a ``MultiheadVariationalAutoencoder``
    and the base LLM. Methods map stored deltas into the VAE latent space, run a
    one-step fine-tune of the LLM starting from a stored delta, and intersect
    latent-space lines to propose new latent models.

    Most configuration names used here (``delta_layer_shape``, ``model_name``,
    ``lora_r``, ...) are module-level settings that presumably come from the
    star imports at the top of the file.
    """

    def __init__(self, data_folder, delta_folder, vae_path, vae_info_path, if_pca=False, pca_folder=None):
        """Load data, VAE, LLM and (optionally) PCA, and set up pre/post-processing.

        Args:
            data_folder: directory with one sub-folder per model; see ``load_data``.
            delta_folder: directory with matching sub-folders holding ``.pkl`` deltas.
            vae_path: path to the VAE ``state_dict``.
            vae_info_path: JSON file; its ``"base_models"`` entry is exposed as
                ``self.BASE_MODELS``.
            if_pca: if True, deltas go through ``pca_preprocess`` before
                ``vae_preprocess`` (and the reverse on the way out).
            pca_folder: directory of PCA objects; required when ``if_pca`` is True.

        Raises:
            ValueError: if ``if_pca`` is True but ``pca_folder`` is None.
        """
        # Fail fast: os.listdir(None) would otherwise list the current directory,
        # and only after the slow data/VAE/LLM loading had already run.
        if if_pca and pca_folder is None:
            raise ValueError("pca_folder is required when if_pca=True")

        print("Loading data...")
        self.load_data(data_folder, delta_folder)
        print("Loading VAE...")
        self.load_vae(vae_path)
        print("Loading model...")
        self.load_llm()

        with open(vae_info_path, 'r') as file:
            self.vae_info = json.load(file)
        self.BASE_MODELS = self.vae_info["base_models"]

        if if_pca:
            # The closures read self.pca_by_layer at call time, so defining them
            # before load_pca() runs is fine.
            def preprocess(raw_delta):
                return vae_preprocess(pca_preprocess(raw_delta, self.pca_by_layer))

            def postprocess(delta_vae_out):
                return pca_postprocess(vae_postprocess(delta_vae_out), self.pca_by_layer)

            print("Loading PCA...")
            self.load_pca(pca_folder)

            self.preprocess_fn = preprocess
            self.postprocess_fn = postprocess
        else:
            self.preprocess_fn = vae_preprocess
            self.postprocess_fn = vae_postprocess

        print("Loaded!!")

    @staticmethod
    def _first_file_containing(folder, substring):
        """Return the first entry of ``os.listdir(folder)`` whose name contains ``substring``.

        Raises:
            FileNotFoundError: if no entry in ``folder`` matches.
        """
        for file_name in os.listdir(folder):
            if substring in file_name:
                return file_name
        raise FileNotFoundError(f"No file containing {substring!r} in {folder!r}")

    def load_data(self, data_folder, delta_folder, exclude=("wallstreetbets_CHAINSAW_VASECTOMY",)):
        """Load every model's LoRA delta and its train/test texts.

        Sub-folders of ``data_folder`` are visited in sorted order. Each must have a
        same-named sub-folder in ``delta_folder`` containing a ``.pkl`` dict of
        tensors, and must itself contain files whose names include ``"test"`` and
        ``"train"``.

        Sets:
            folder_name_list: loaded folder names, in load order.
            test_data_list / train_data_list: per model, the file text split on
                blank lines (``'\\n\\n'``) with the final chunk dropped.
            raw_data: stacked deltas, shape
                ``(n_models, n_layers, *delta_layer_shape[1:])``.
            layer_names: layer keys, in the first delta file's order.
            og_shapes: per-layer tensor shapes from the first delta file.

        Args:
            data_folder: root of the per-model text folders.
            delta_folder: root of the per-model delta folders.
            exclude: folder names to skip.

        Raises:
            ValueError: if a delta file's layer names differ from the first file's,
                or no model folders were loaded.
            FileNotFoundError: if a required file is missing from a folder.
        """
        delta_list = []
        self.test_data_list = []
        self.train_data_list = []
        self.folder_name_list = []
        layer_names = None
        og_shapes = None

        for folder_name in sorted(os.listdir(data_folder)):
            if folder_name in exclude:
                continue
            curr_data_folder = os.path.join(data_folder, folder_name)
            curr_delta_folder = os.path.join(delta_folder, folder_name)
            self.folder_name_list.append(folder_name)
            delta_file = self._first_file_containing(curr_delta_folder, '.pkl')
            data_file = self._first_file_containing(curr_data_folder, 'test')
            train_data_file = self._first_file_containing(curr_data_folder, 'train')

            # NOTE(review): pickle can execute arbitrary code; only load trusted delta files.
            with open(os.path.join(curr_delta_folder, delta_file), 'rb') as f:
                d = pickle.load(f)

            if layer_names is None:
                layer_names = list(d.keys())
            elif set(layer_names) != set(d.keys()):
                raise ValueError(
                    f"{delta_file!r} in {folder_name!r} has different layer names than the first delta file"
                )
            # Index by the first file's key order so position i is the same layer for
            # every model, even if a pickle stores its keys in a different order.
            loaded_tensor_list = [d[name] for name in layer_names]

            if og_shapes is None:
                og_shapes = [t.shape for t in loaded_tensor_list]

            delta = torch.stack([t.reshape(delta_layer_shape[1:]) for t in loaded_tensor_list], axis=0)

            with open(os.path.join(curr_data_folder, data_file), 'r') as f:
                test = f.read().split('\n\n')[:-1]
            with open(os.path.join(curr_data_folder, train_data_file), 'r') as f:
                train = f.read().split('\n\n')[:-1]
            self.test_data_list.append(test)
            self.train_data_list.append(train)
            delta_list.append(delta)

        if not delta_list:
            raise ValueError(f"No model folders found in {data_folder!r}")
        self.raw_data = torch.stack(delta_list, axis=0)
        self.layer_names = layer_names
        self.og_shapes = og_shapes

    def load_llm(self):
        """Load the 4-bit quantized base model and tokenizer and build the LoRA/training configs.

        Sets ``self.model``, ``self.tokenizer``, ``self.peft_config``,
        ``self.training_arguments`` (reused by ``_backprob_one_step``) and
        ``self.og_weights`` (clones of the model's ``lora`` parameters right after setup).
        """
        # Load tokenizer and model with QLoRA configuration
        compute_dtype = getattr(torch, bnb_4bit_compute_dtype)

        bnb_config = BitsAndBytesConfig(
            load_in_4bit=use_4bit,
            bnb_4bit_quant_type=bnb_4bit_quant_type,
            bnb_4bit_compute_dtype=compute_dtype,
            bnb_4bit_use_double_quant=use_nested_quant,
        )

        # Check GPU compatibility with bfloat16
        if compute_dtype == torch.float16 and use_4bit:
            major, _ = torch.cuda.get_device_capability()
            if major >= 8:
                print("=" * 80)
                print("Your GPU supports bfloat16: accelerate training with bf16=True")
                print("=" * 80)

        # Load base model
        model = AutoModelForCausalLM.from_pretrained(
            model_name,
            quantization_config=bnb_config,
            device_map=device_map
        )
        model.config.use_cache = False
        model.config.pretraining_tp = 1

        # Load LLaMA tokenizer
        tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
        tokenizer.pad_token = tokenizer.eos_token
        tokenizer.padding_side = "right"  # Fix weird overflow issue with fp16 training

        self.peft_config = LoraConfig(
            lora_alpha=lora_alpha,
            lora_dropout=lora_dropout,
            r=lora_r,
            bias="none",
            task_type="CAUSAL_LM",
        )

        self.training_arguments = TrainingArguments(
            output_dir=output_dir,
            num_train_epochs=num_train_epochs,
            per_device_train_batch_size=per_device_train_batch_size,
            gradient_accumulation_steps=gradient_accumulation_steps,
            optim=optim,
            save_steps=save_steps,
            logging_steps=logging_steps,
            learning_rate=learning_rate,
            weight_decay=weight_decay,
            fp16=fp16,
            bf16=bf16,
            max_grad_norm=max_grad_norm,
            max_steps=max_steps,
            warmup_ratio=warmup_ratio,
            group_by_length=group_by_length,
            lr_scheduler_type=lr_scheduler_type,
        )

        # NOTE(review): this trainer is otherwise unused. Presumably constructing it with
        # peft_config injects the LoRA layers into `model` in place, which the 'lora'
        # parameter scan below relies on. Do not remove without checking.
        trainer = SFTTrainer(
            model=model,
            dataset_text_field="text",
            peft_config=self.peft_config,
            max_seq_length=max_seq_length,
            tokenizer=tokenizer,
            args=self.training_arguments,
            packing=packing,
        )

        self.og_weights = [param.clone() for name, param in model.named_parameters() if 'lora' in name]

        self.model = model
        self.tokenizer = tokenizer

    def load_pca(self, pca_folder):
        """Load PCA objects from ``pca_folder`` into ``self.pca_by_layer``.

        Files are ordered by the integer prefix before their first ``'_'``
        (e.g. ``"3_..."``). Every file name must have such a prefix.
        """
        # NOTE(review): `load` comes from the star imports (presumably joblib.load).
        file_names = sorted(os.listdir(pca_folder), key=lambda x: int(x[:x.find('_')]))
        self.pca_by_layer = [load(os.path.join(pca_folder, file_name)) for file_name in file_names]

    def load_vae(self, vae_path):
        """Build the VAE on CPU and load its weights from ``vae_path``."""
        vae = MultiheadVariationalAutoencoder(input_size, output_shape, latent_dims, hidden_dims, n_inputs).cpu()
        # map_location="cpu": the model lives on CPU, so do not require the device
        # the checkpoint was saved from.
        vae.load_state_dict(torch.load(vae_path, map_location="cpu"))
        self.vae = vae

    def decode_latent_model(self, latent_tensor):
        """Decode a latent vector with the VAE decoder and undo the preprocessing."""
        delta_pca_out = self.vae.decoder(latent_tensor)
        delta_out = self.postprocess_fn(delta_pca_out)
        return delta_out

    def get_tuned_model(self, delta):
        """Return the LLM with ``delta`` applied to its LoRA layers.

        NOTE(review): whether ``self.model`` is modified in place depends on
        ``apply_delta`` (defined elsewhere); confirm before reusing ``self.model``.
        """
        return apply_delta(self.model, self.layer_names, back_to_og_shapes(delta, self.og_shapes))

    def _find_intersection_points(self, lines, interpolation_privilege=None):
        """Least-squares "intersections" of pairs of rays ``x + t * y``.

        For each considered pair, ``x1 + t1*y1 = x2 + t2*y2`` is solved in the
        least-squares sense. If both ``t1`` and ``t2`` are non-negative, the midpoint
        of the two resulting points is kept.

        Args:
            lines: sequence of ``(x, y)`` pairs of equal-length vectors.
            interpolation_privilege: if None, every unordered pair of lines is used.
                Otherwise only pairs of each line that differs from
                ``lines[interpolation_privilege]`` with that line.

        Returns:
            ``(intersection_points, found_solution)``: midpoints of accepted pairs,
            and one bool per considered pair telling whether it was accepted.
        """
        if interpolation_privilege is None:
            pairs = combinations(lines, 2)
        else:
            anchor = lines[interpolation_privilege]
            pairs = [(line, anchor) for line in lines if (line != anchor).any()]

        intersection_points = []
        found_solution = []
        for (x1, y1), (x2, y2) in pairs:
            A = np.vstack((y1, -y2)).T
            b = x2 - x1
            t_value = np.linalg.lstsq(A, b, rcond=None)[0]
            accepted = bool(np.all(t_value >= 0))
            if accepted:
                point1 = x1 + t_value[0] * y1
                point2 = x2 + t_value[1] * y2
                intersection_points.append((point1 + point2) / 2)
            found_solution.append(accepted)

        return intersection_points, found_solution

    def get_line(self, base_model, one_step_delta):
        """Return the latent-space line ``(base_mu, diff_mu)`` for a base delta and its one-step update."""
        base_mu, _, diff_mu, _ = get_encoder_diff(self.vae, base_model, one_step_delta.detach(), self.preprocess_fn)
        return (base_mu.detach(), diff_mu.detach())

    def get_lines(self, model_names, one_step_deltas):
        """Return ``np.array`` of lines, one per ``one_step_deltas[i]`` paired with ``model_names[i]``."""
        base_models = [self.raw_data[self.folder_name_list.index(model_name)] for model_name in model_names]
        return np.array([self.get_line(base_models[i], delta) for i, delta in enumerate(one_step_deltas)])

    def interpolate(self, model_names, os_deltas, n_samples_id, int_type):
        """Propose latent models by intersecting the models' latent-space lines.

        Args:
            model_names: names of the base models (entries of ``folder_name_list``).
            os_deltas: sequence whose entries are lists of one-step deltas aligned
                with ``model_names`` (presumably one entry per sample-count setting).
            n_samples_id: index into ``os_deltas``.
            int_type: ``"vanilla_linear"`` uses ``os_deltas[n_samples_id]`` only.
                ``"averaging"`` averages each model's direction over
                ``os_deltas[0..n_samples_id]``.

        Returns:
            ``(latent_models, found_solution)`` as produced by
            ``_find_intersection_points``, with points converted to tensors.

        Raises:
            ValueError: for an unknown ``int_type``, or a negative ``n_samples_id``
                with ``"averaging"``.
        """
        if int_type == "vanilla_linear":
            lines = self.get_lines(model_names, os_deltas[n_samples_id])
        elif int_type == "averaging":
            if n_samples_id < 0:
                raise ValueError("n_samples_id must be >= 0 for 'averaging'")
            n_models = len(model_names)
            changes = [[] for _ in range(n_models)]
            for i in range(n_samples_id + 1):
                lines = self.get_lines(model_names, os_deltas[i])
                for k in range(n_models):
                    changes[k].append(lines[k][1])
            # Keep each model's base point (from the last pass) and average its
            # direction over all passes.
            new_lines = [
                (lines[k][0], torch.mean(torch.tensor(np.stack(changes[k])), axis=0))
                for k in range(n_models)
            ]
            lines = [torch.tensor(line) for line in new_lines]
        else:
            raise ValueError(f"Unknown int_type: {int_type!r}")

        intersection_points, found_solution = self._find_intersection_points(lines)
        latent_models = [torch.tensor(a) for a in intersection_points]
        return latent_models, found_solution

    def backprob_one_model(self, sample_corpus, model_name):
        """Fine-tune from the stored delta of ``model_name`` on ``sample_corpus``.

        Returns:
            ``(base_model, one_step_delta)``: the stored delta and the LoRA weights
            after training.
        """
        base_model = self.raw_data[self.folder_name_list.index(model_name)]
        one_step_delta = self._backprob_one_step(sample_corpus, self.model, base_model)
        return base_model, one_step_delta

    def _backprob_one_step(self, corpus, model, base_delta):
        """Load ``base_delta`` into ``model``, train on ``corpus``, and return its LoRA weights.

        Returns:
            Tensor stacking every ``lora`` parameter, each reshaped to
            ``delta_layer_shape[1:]``.
        """
        train_dataset = Dataset.from_dict({"text": corpus})

        # Preserve Python's `random` state across trainer construction, which may
        # reseed it (presumably via the TrainingArguments seed).
        random_state = random.getstate()

        trainer = SFTTrainer(
            model=model,
            train_dataset=train_dataset,
            dataset_text_field="text",
            peft_config=self.peft_config,
            max_seq_length=max_seq_length,
            tokenizer=self.tokenizer,
            args=self.training_arguments,
            packing=packing,
        )

        random.setstate(random_state)

        model = apply_delta(model, self.layer_names, back_to_og_shapes(base_delta, self.og_shapes))

        trainer.train()

        # These are the trained LoRA weights themselves, not a difference from base_delta.
        lora_params = [param.clone() for name, param in model.named_parameters() if 'lora' in name]
        return torch.stack([t.reshape(delta_layer_shape[1:]) for t in lora_params], axis=0)

    def represent_in_latent_space(self, model_names, n_draws=25):
        """Encode each named model's stored delta ``n_draws`` times.

        Returns:
            One list per model of ``n_draws`` detached CPU encoder outputs.
            The encoder is called repeatedly, presumably because its output is
            sampled — confirm against the VAE implementation.
        """
        out_list = []
        for model_name in model_names:
            delta = self.raw_data[self.folder_name_list.index(model_name)]
            encoder_outputs = [
                self.vae.encoder(self.preprocess_fn(delta)).cpu().detach() for _ in range(n_draws)
            ]
            out_list.append(encoder_outputs)
        return out_list

    def get_latent_space_params(self, model_names):
        """Return ``(mu, sigma, kl)`` read from the encoder after encoding each named model's delta."""
        out_list = []
        for model_name in model_names:
            delta = self.raw_data[self.folder_name_list.index(model_name)]
            # Called for its side effect: the encoder stores mu/sigma/kl as attributes.
            self.vae.encoder(self.preprocess_fn(delta))
            out_list.append((self.vae.encoder.mu, self.vae.encoder.sigma, self.vae.encoder.kl))
        return out_list

    def least_distance_combination(self, tensors, k):
        """Find the combination of k tensors with the least total pairwise Euclidean distance.

        Args:
            tensors: list of equally-shaped vectors.
            k: size of combinations to consider.

        Returns:
            The first size-``k`` tuple with the smallest summed pairwise
            ``euclidean_distance``, or None if there is no combination
            (``k > len(tensors)``).
        """
        min_distance = float('inf')
        least_distance_comb = None
        # Iterate lazily; the number of combinations grows combinatorially.
        for comb in itertools.combinations(tensors, k):
            total_distance = sum(euclidean_distance(a, b) for a, b in itertools.combinations(comb, 2))
            if total_distance < min_distance:
                min_distance = total_distance
                least_distance_comb = comb
        return least_distance_comb
